#!/usr/bin/env python3
# src/t3/plateau.py
# Plateau finder used by run_t3.py
# Returns (i0, i1, ok, stats) where stats includes:
#   A_theta, rmse_flat, R2_flat, rmse, R2, n_bins

from __future__ import annotations
import numpy as np, math
from typing import Optional, Tuple, Dict

# -------- helpers --------

def _smooth_P(P: np.ndarray, w: int = 5) -> np.ndarray:
    """Return a running-median smoothed copy of P.

    The window width is ``int(w) | 1`` (an even ``w`` is bumped to the next
    odd value) and the series is edge-padded by half a window on each side,
    so the output has one value per input sample.  When ``w <= 1`` or P has
    fewer than three samples, the float-converted input is returned as is.
    """
    values = np.asarray(P, dtype=float)
    if w <= 1 or values.size < 3:
        return values

    window = int(w) | 1          # force an odd window so it has a centre
    half = window // 2
    # Repeat the first/last sample so edge windows are full width.
    padded = np.pad(values, (half, half), mode='edge')

    count = len(values)
    smoothed = np.empty(count, dtype=float)
    for start in range(count):
        # padded[start:start + window] is centred on values[start].
        smoothed[start] = np.median(padded[start:start + window])
    return smoothed

def _linfit(x: np.ndarray, y: np.ndarray) -> Tuple[float, float, float, float]:
    """Least-squares straight line through (x, y).

    Fits ``y = slope * x + intercept`` and returns
    ``(slope, intercept, rmse, R2)``.  All four values are NaN when fewer
    than two points are given or when any x or y value is non-finite.
    R2 is reported as 0.0 when y has no variance.
    """
    xs = np.asarray(x, dtype=float)
    ys = np.asarray(y, dtype=float)

    too_short = xs.size < 2
    if too_short or not (np.all(np.isfinite(xs)) and np.all(np.isfinite(ys))):
        return (np.nan, np.nan, np.nan, np.nan)

    # Design matrix with columns [x, 1] for the slope and intercept.
    design = np.vstack([xs, np.ones_like(xs)]).T
    solution = np.linalg.lstsq(design, ys, rcond=None)[0]
    slope, intercept = solution

    residuals = ys - (slope * xs + intercept)
    sq_resid = residuals**2
    rmse = float(np.sqrt(np.mean(sq_resid)))
    ss_res = float(np.sum(sq_resid))
    ss_tot = float(np.sum((ys - np.mean(ys))**2))

    # A constant y has no variance to explain; avoid dividing by zero.
    if ss_tot > 0:
        r_squared = 1.0 - ss_res / ss_tot
    else:
        r_squared = 0.0
    return float(slope), float(intercept), rmse, float(r_squared)

# -------- core API expected by run_t3.py --------

def select_flat_window(
    b: np.ndarray,
    P: np.ndarray,
    slope_abs_max: float,
    min_bins: int,
    bmin: float,
    bmax: float,
) -> Tuple[Optional[int], Optional[int], bool, Dict[str, float]]:
    """
    Find the longest contiguous run of bins over which P is flat in b.

    RETURNS: (i0, i1, ok, stats)

    Bins are first restricted to those with finite b and P and with
    bmin <= b <= bmax; P on those bins is median-smoothed (window 5).  A
    window [i0, i1] (inclusive) is acceptable when a straight-line fit of
    smoothed P against b has |slope| <= slope_abs_max.  The longest
    acceptable window wins; equal-length windows are ranked by lower fit
    RMSE, then (for RMSEs that compare close) by higher R2.

    Parameters
    ----------
    b, P : per-bin arrays of equal length (the original note describes P as
        usually gamma_t * b * [factor]).
    slope_abs_max : largest |dP/db| accepted as "flat".
    min_bins : minimum window length; coerced with int() and never taken
        below 2, because a line fit needs two points.
    bmin, bmax : inclusive b range considered.

    Returns
    -------
    i0, i1 : inclusive window bounds, or None when no window qualifies.
        NOTE(review): these index the *filtered* bin sequence (after the
        finite/range mask), not the arrays passed in -- confirm the caller
        applies the same mask before slicing.
    ok : True when a window was found.
    stats : dict with keys
        'A_theta'   median of the smoothed P over the window (NaN on failure),
        'rmse','R2' line-fit quality on the window (NaN on failure),
        'rmse_flat','R2_flat' currently the same values as 'rmse','R2',
        'n_bins'    window length on success, number of usable bins on failure.
    """
    b = np.asarray(b, dtype=float)
    P = np.asarray(P, dtype=float)

    # Usable bins: finite in both arrays and inside the requested b range.
    mask = np.isfinite(b) & np.isfinite(P) & (b >= bmin) & (b <= bmax)
    b2 = b[mask]
    # The smoothed series drives both the window selection and the level.
    P2 = _smooth_P(P[mask], w=5)
    n = int(b2.size)

    # Single definition of the "no plateau" result used by every failure path.
    rejected = (None, None, False, {
        "A_theta": math.nan,
        "rmse": math.nan, "R2": math.nan,
        "rmse_flat": math.nan, "R2_flat": math.nan,
        "n_bins": float(n),
    })

    # A fit needs >= 2 points; int() also accepts float-valued config entries.
    min_len = max(2, int(min_bins))
    if n < min_len:
        return rejected

    # Only the longest acceptable window can win, so try lengths from longest
    # to shortest and stop at the first length that yields a candidate.
    # Within one length, candidates are compared in ascending i0 order.
    for length in range(n, min_len - 1, -1):
        best = None  # (i0, i1, rmse, R2) for this length
        for i0 in range(n - length + 1):
            i1 = i0 + length - 1
            a, _, rmse, R2 = _linfit(b2[i0:i1 + 1], P2[i0:i1 + 1])
            if not np.isfinite(a):
                continue
            # Written as "not <=" so a NaN threshold rejects every window.
            if not abs(a) <= slope_abs_max:
                continue
            if (
                best is None
                or rmse < best[2]
                or (math.isclose(rmse, best[2]) and R2 > best[3])
            ):
                best = (i0, i1, rmse, R2)

        if best is None:
            continue

        i0, i1, rmse, R2 = best
        # Plateau amplitude: median level of the smoothed series on the window.
        A_theta = float(np.median(P2[i0:i1 + 1]))
        stats = {
            "A_theta": A_theta,
            "rmse": float(rmse),
            "R2": float(R2),
            "rmse_flat": float(rmse),
            "R2_flat": float(R2),
            "n_bins": float(length),
        }
        return (int(i0), int(i1), True, stats)

    return rejected
